{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 随机梯度下降法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Fix the seed so that Restart & Run All regenerates the same synthetic sample.\n",
    "np.random.seed(666)\n",
    "\n",
    "# Synthetic 1-D linear data: y = 4 * x + 3 plus Gaussian noise (std 3).\n",
    "m = 100000\n",
    "x = np.random.normal(size=m)\n",
    "y = x * 4. + 3. + np.random.normal(0., 3., size=m)\n",
    "X = x.reshape(-1, 1)  # feature matrix with a single column, shape (m, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def J(X_b, y, theta):\n",
    "    \"\"\"Mean squared error of the linear model X_b.dot(theta) against y.\n",
    "\n",
    "    Parameters: X_b is the design matrix (one row per sample), y the targets,\n",
    "    theta the parameter vector (one entry per column of X_b).\n",
    "\n",
    "    Returns the sum of squared residuals divided by len(X_b). If the\n",
    "    computation raises a numeric overflow error, float('inf') is returned so\n",
    "    a diverging run shows up as an infinite loss.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        return np.sum((y - X_b.dot(theta)) ** 2) / len(X_b)\n",
    "    except (OverflowError, FloatingPointError):\n",
    "        # Only numeric blow-ups map to an infinite loss. Shape mismatches,\n",
    "        # KeyboardInterrupt, etc. must propagate instead of being hidden.\n",
    "        return float('inf')\n",
    "    \n",
    "def dJ(X_b, y, theta):\n",
    "    \"\"\"Gradient of the mean-squared-error loss with respect to theta.\n",
    "\n",
    "    Computes X_b.T.dot(X_b.dot(theta) - y) * 2 / len(X_b).\n",
    "\n",
    "    The sample count comes from X_b itself rather than from a notebook-level\n",
    "    variable, so the result is correct for whatever data is passed in.\n",
    "    \"\"\"\n",
    "    return X_b.T.dot(X_b.dot(theta) - y) * 2 / len(X_b)\n",
    "\n",
    "def gradient_descent(X_b, y, initial_theta, eta=0.01, epsilon = 1e-8, n_iters = 1e4):\n",
    "    \"\"\"Batch gradient descent using the loss J and its gradient dJ.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X_b : design matrix passed through to J / dJ.\n",
    "    y : target vector passed through to J / dJ.\n",
    "    initial_theta : starting parameter vector; it is copied, never modified.\n",
    "    eta : step size applied to the gradient.\n",
    "    epsilon : stop once the loss changes by less than this between two steps.\n",
    "    n_iters : upper bound on the number of update steps.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The parameter vector after the last update step.\n",
    "    \"\"\"\n",
    "    # Work on a private float copy so the caller's array is left untouched.\n",
    "    theta = np.array(initial_theta, dtype=float)\n",
    "    cur_iter = 0\n",
    "\n",
    "    while cur_iter < n_iters:\n",
    "        gradient = dJ(X_b, y, theta)\n",
    "        last_theta = theta\n",
    "        # Rebind instead of using `-=`: an in-place update would also change\n",
    "        # last_theta (same array), so the loss difference below would always\n",
    "        # be 0 and the loop would stop after a single step.\n",
    "        theta = theta - eta * gradient\n",
    "\n",
    "        if abs(J(X_b, y, last_theta) - J(X_b, y, theta)) < epsilon:\n",
    "            break\n",
    "\n",
    "        cur_iter += 1\n",
    "\n",
    "    return theta"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def dJ_(X_b_i, y_i, theta):\n",
    "    \"\"\"Un-averaged squared-error gradient: 2 * X_b_i.T.dot(X_b_i.dot(theta) - y_i).\n",
    "\n",
    "    Unlike dJ there is no division by a sample count. Presumably meant to be\n",
    "    called with a single sample (row of X_b and matching y) -- confirm at the\n",
    "    call site.\n",
    "    \"\"\"\n",
    "    residual = X_b_i.dot(theta) - y_i\n",
    "    return 2 * X_b_i.T.dot(residual)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "shapes (100000,2) and (100000,) not aligned: 2 (dim 1) != 100000 (dim 0)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-28-17fb87e2eb4c>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0minitial_theta\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzeros\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mtheta\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgradient_descent\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_b\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minitial_theta\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m<ipython-input-27-560dca78ac02>\u001b[0m in \u001b[0;36mgradient_descent\u001b[0;34m(X_b, y, initial_theta, eta, epsilon, n_iters)\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m     \u001b[0;32mwhile\u001b[0m \u001b[0mcur_iter\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0mn_iters\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m         \u001b[0mgradient\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdJ\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_b\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtheta\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      7\u001b[0m         \u001b[0mlast_theta\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtheta\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      8\u001b[0m         \u001b[0mtheta\u001b[0m \u001b[0;34m-=\u001b[0m \u001b[0meta\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-26-8f0fa1553df3>\u001b[0m in \u001b[0;36mdJ\u001b[0;34m(X_b, y, theta)\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mdJ\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_b\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtheta\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0mX_b\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mT\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_b\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtheta\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;36m2\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m: shapes (100000,2) and (100000,) not aligned: 2 (dim 1) != 100000 (dim 0)"
     ]
    }
   ],
   "source": [
    "# Prepend a column of ones so that theta[0] acts as the intercept.\n",
    "X_b = np.hstack([np.ones((len(X), 1)), X])\n",
    "# One parameter per column of X_b, not one per sample: np.zeros(len(X))\n",
    "# produced a length-100000 theta and X_b.dot(theta) could not be aligned.\n",
    "initial_theta = np.zeros(X_b.shape[1])\n",
    "\n",
    "theta = gradient_descent(X_b, y, initial_theta)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
